{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,\n",
       "        0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]),\n",
       " 0)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def load_data():\n",
    "    lines = [['my', 'dog', 'has', 'flea', 'problems', 'help', 'please'],\n",
    "             ['maybe', 'not', 'take', 'him', 'to', 'dog', 'park', 'stupid'],\n",
    "             ['my', 'dalmation', 'is', 'so', 'cute', 'I', 'love', 'him'],\n",
    "             ['stop', 'posting', 'stupid', 'worthless', 'garbage'],\n",
    "             ['mr', 'licks', 'ate', 'my', 'steak', 'how', 'to', 'stop', 'him'],\n",
    "             ['quit', 'buying', 'worthless', 'dog', 'food', 'stupid']]\n",
    "\n",
    "    #求所有单词的集合\n",
    "    words = []\n",
    "    for line in lines:\n",
    "        words.extend(line)\n",
    "    words = list(set(words))\n",
    "\n",
    "    #句子数字化,出现了某单词就是1,否则是0\n",
    "    x = np.zeros((len(lines), len(words)))\n",
    "    for i in range(len(lines)):\n",
    "        for j in range(len(words)):\n",
    "            if words[j] in lines[i]:\n",
    "                x[i, j] = 1\n",
    "\n",
    "    #1是违规言论,0合法言论\n",
    "    y = np.array([0, 1, 0, 1, 0, 1])\n",
    "    return x, y\n",
    "\n",
    "\n",
    "x, y = load_data()\n",
    "x[0], y[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([-3.93182563, -3.93182563, -2.54553127, -3.23867845, -3.93182563]),\n",
       " array([-3.33220451, -3.33220451, -4.02535169, -4.02535169, -3.33220451]),\n",
       " -0.6931471805599453,\n",
       " -0.6931471805599453)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train(x, y):\n",
    "\n",
    "    #首先求总体的违规率,等于 违规次数 / 总次数\n",
    "    p1 = y.sum() / len(y)\n",
    "    p0 = 1 - p1\n",
    "\n",
    "    #取对数概率\n",
    "    p1 = np.log(p1)\n",
    "    p0 = np.log(p0)\n",
    "\n",
    "    #根据y的值,切分x为正例和反例\n",
    "    x_1 = x[y == 1]\n",
    "    x_0 = x[y == 0]\n",
    "\n",
    "    #统计在正例中,所有词出现的次数,当然在一个句子中出现多次也只算1次\n",
    "    #也就是说,是每个词出现的句子数\n",
    "    #最后要加1是为了避免0的情况,也就是说,所有词最少出现1次\n",
    "    p1_given_word = x_1.sum(axis=0) + 1\n",
    "    #[4. 4. 4. 4. 4. 3. 2. 2. 1. 1. 1. 1. 1. 1. 1. 1.....]\n",
    "\n",
    "    #上面统计的是次数,除以总次数就等于概率了.\n",
    "    #也就是每个词,出现在正例句子中的概率\n",
    "    p1_given_word = p1_given_word / p1_given_word.sum()\n",
    "\n",
    "    #取对数概率\n",
    "    p1_given_word = np.log(p1_given_word)\n",
    "\n",
    "    #p0_given_word的计算同理\n",
    "    p0_given_word = x_0.sum(axis=0) + 1\n",
    "    p0_given_word = p0_given_word / p0_given_word.sum()\n",
    "    p0_given_word = np.log(p0_given_word)\n",
    "\n",
    "    #取对数概率\n",
    "    return p1_given_word, p0_given_word, p1, p0\n",
    "\n",
    "\n",
    "p1_given_word, p0_given_word, p1, p0 = train(x, y)\n",
    "p1_given_word[:5], p0_given_word[:5], p1, p0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#测试\n",
    "def pred(x):\n",
    "    #本来p1_given_x应该是每个单词属于p1的概率连乘, 但是因为前面取了对数, 所以这里求和就可以了\n",
    "    p1_given_x = x.dot(p1_given_word) + p1\n",
    "    p0_given_x = x.dot(p0_given_word) + p0\n",
    "\n",
    "    return 1 if p1_given_x > p0_given_x else 0\n",
    "\n",
    "\n",
    "correct = 0\n",
    "for xi, yi in zip(x, y):\n",
    "    if pred(xi) == yi:\n",
    "        correct += 1\n",
    "\n",
    "correct / len(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
